[Tests] clean up and refactor gradient checkpointing tests #9494
Conversation
@unittest.skip(
    "Gradient checkpointing is supported, but this test doesn't apply to this class because its forward is a bit different from the rest."
)
Other forwards don't apply scaling and unscaling the way this one does, which makes it a little different.
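For context, here is a minimal sketch of what such a forward might look like; `ScaledBlockModel` and its scale factor are hypothetical placeholders for illustration, not the actual diffusers class. The input is scaled before the checkpointed blocks and unscaled afterwards, which is what makes the generic gradient checkpointing test a poor fit here.

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ScaledBlockModel(nn.Module):
    """Hypothetical module whose forward scales/unscales around checkpointed blocks."""

    def __init__(self, dim=32, num_blocks=2, scale=0.18215):
        super().__init__()
        self.scale = scale
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])
        self.gradient_checkpointing = False

    def forward(self, x):
        # Scaling is applied on the way in ...
        x = x * self.scale
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        # ... and undone on the way out, unlike the other forwards.
        return x / self.scale
```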
@yiyixuxu WDYT?

@DN6 could you give this a look?

@DN6 a gentle ping here.
Failing test is unrelated. |
* check.
* fixes
* fixes
* updates
* fixes
* fixes
What does this PR do?
Gradient checkpointing is an essential component of model training. We need to ensure it's implemented properly.
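To illustrate the kind of check such a test performs, here is a minimal sketch: run the same forward/backward pass with and without gradient checkpointing enabled and verify the gradients match. `TinyModel`, the flag name, and the tolerance are illustrative placeholders, not the actual diffusers test code.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyModel(nn.Module):
    """Stand-in model; the real tests exercise the actual diffusers model classes."""

    def __init__(self, dim=16):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(3)])
        self.gradient_checkpointing = False

    def forward(self, x):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def test_gradient_checkpointing_matches_baseline():
    torch.manual_seed(0)
    inputs = torch.randn(2, 16)

    model = TinyModel().train()
    model_ckpt = TinyModel().train()
    model_ckpt.load_state_dict(model.state_dict())
    model_ckpt.gradient_checkpointing = True

    model(inputs).sum().backward()
    model_ckpt(inputs).sum().backward()

    # Gradients should be (numerically) identical whether or not checkpointing is on.
    for p, p_ckpt in zip(model.parameters(), model_ckpt.parameters()):
        assert torch.allclose(p.grad, p_ckpt.grad, atol=1e-5)
```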
If these tests had been implemented properly, we could have avoided the need for the fix in #9489. Another related issue: #9496.
Additionally, I took the liberty of properly skipping the related tests with `unittest.skip`, so we know the tests are actually being skipped.
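For illustration, a decorated test shows up as skipped in the test report instead of silently passing; the class and test names below are placeholders, not the actual test suite.

```python
import unittest


class ModelGradientCheckpointingTests(unittest.TestCase):
    @unittest.skip("Gradient checkpointing is supported, but this model's forward differs from the rest.")
    def test_gradient_checkpointing(self):
        ...  # Never runs; unittest reports it as skipped, with the reason above.

    def test_something_else(self):
        self.assertTrue(True)
```

Running this file with `python -m unittest` reports one skipped test and one passing test, rather than two passes.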